{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchtext\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from torchtext.vocab import Vectors\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fields: one for the sentence text, one non-sequential field for the label.\n",
    "text = torchtext.data.Field(include_lengths=False)\n",
    "label = torchtext.data.Field(sequential=False)\n",
    "\n",
    "# SST train/val/test splits with every 'neutral' example filtered out.\n",
    "train, val, test = torchtext.datasets.SST.splits(\n",
    "    text,\n",
    "    label,\n",
    "    filter_pred=lambda ex: ex.label != 'neutral',\n",
    ")\n",
    "\n",
    "# Both vocabularies are built from the training split only.\n",
    "for field in (text, label):\n",
    "    field.build_vocab(train)\n",
    "\n",
    "# One bucketed iterator per split, 10 examples per batch.\n",
    "train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(\n",
    "    (train, val, test),\n",
    "    batch_size=10,\n",
    "    device=-1,\n",
    "    repeat=False,\n",
    ")\n",
    "\n",
    "# Attach the pretrained wiki.simple word vectors to the text vocabulary.\n",
    "url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'\n",
    "text.vocab.load_vectors(vectors=Vectors('wiki.simple.vec', url=url))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    \"\"\"Bag-of-words logistic regression classifier.\n",
    "\n",
    "    Every example is encoded as a binary vocabulary-occurrence vector and\n",
    "    passed through a single linear layer that yields one score per class.\n",
    "    Targets are shifted by -1 before the loss and predictions by +1, so the\n",
    "    labels handled here are 1-based.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, num_classes, batch_size):\n",
    "        \"\"\"Create the linear layer.\n",
    "\n",
    "        Args:\n",
    "            input_size: Width of the occurrence vector (vocabulary size).\n",
    "            num_classes: Number of output classes.\n",
    "            batch_size: Nominal batch size. It only sizes the legacy\n",
    "                array_like template; inputs are sized from each batch.\n",
    "        \"\"\"\n",
    "        super(LogisticRegression, self).__init__()\n",
    "        self.linear = nn.Linear(input_size, num_classes, bias=True)\n",
    "        # Kept so existing code that reads this attribute keeps working.\n",
    "        self.array_like = np.zeros((batch_size, input_size))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the element-wise sigmoid of the per-class linear scores.\"\"\"\n",
    "        return torch.nn.functional.sigmoid(self.linear(x))\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the 1-based index of the highest-scoring class per row.\"\"\"\n",
    "        log_probs = torch.nn.functional.log_softmax(self.forward(x), dim=1)\n",
    "        # max(1) returns (values, indices); +1 maps back to 1-based labels.\n",
    "        return log_probs.max(1)[1] + 1\n",
    "\n",
    "    def binarize_occurrences(self, indices):\n",
    "        \"\"\"Encode token-index rows as binary occurrence vectors.\n",
    "\n",
    "        Args:\n",
    "            indices: Sequence with one sequence of token indices per example.\n",
    "\n",
    "        Returns:\n",
    "            numpy array of shape (len(indices), in_features) holding 1 where\n",
    "            a token index occurs in the example and 0 elsewhere.\n",
    "        \"\"\"\n",
    "        # Size the matrix from the batch itself. A fixed batch_size template\n",
    "        # pads short final batches with phantom all-zero rows (which then get\n",
    "        # predicted and written out) and overflows on larger batches.\n",
    "        occurrences = np.zeros((len(indices), self.linear.in_features))\n",
    "        for row, token_ids in enumerate(indices):\n",
    "            occurrences[row][token_ids] = 1\n",
    "        return occurrences\n",
    "\n",
    "    def batch_to_input(self, batch, train=True):\n",
    "        \"\"\"Convert a batch into model input.\n",
    "\n",
    "        Args:\n",
    "            batch: Object exposing .text (token indices) and .label.\n",
    "            train: When True the labels are returned alongside the inputs.\n",
    "\n",
    "        Returns:\n",
    "            (inputs, batch.label) when train is True, otherwise inputs only.\n",
    "        \"\"\"\n",
    "        # Transposed so that each row holds the token indices of one example\n",
    "        # (assumes batch.text is laid out as (seq_len, batch) -- TODO confirm).\n",
    "        word_indices = batch.text.data.numpy().T\n",
    "        inputs = Variable(torch.FloatTensor(self.binarize_occurrences(word_indices)))\n",
    "        if train:\n",
    "            return inputs, batch.label\n",
    "        return inputs\n",
    "\n",
    "    @staticmethod\n",
    "    def _loss_value(loss):\n",
    "        \"\"\"Return a scalar loss as a Python number on old and new PyTorch.\"\"\"\n",
    "        # .item() is the supported accessor where it exists; loss.data[0] is\n",
    "        # the fallback for releases that predate it.\n",
    "        return loss.item() if hasattr(loss, 'item') else loss.data[0]\n",
    "\n",
    "    def _collect_predictions(self, data_iter):\n",
    "        \"\"\"Return (predicted, true) label lists gathered over data_iter.\"\"\"\n",
    "        predicted, actual = [], []\n",
    "        for batch in data_iter:\n",
    "            inputs = self.batch_to_input(batch, train=False)\n",
    "            predicted.extend(self.predict(inputs).data.numpy())\n",
    "            actual.extend(batch.label.data.numpy())\n",
    "        return predicted, actual\n",
    "\n",
    "    def train(self, train_iter=True, val_iter=None, num_epochs=None, learning_rate=1e-3):\n",
    "        \"\"\"Fit the model, or toggle training mode when called with a bool.\n",
    "\n",
    "        This method shadows nn.Module.train(mode). A bool first argument (or\n",
    "        no argument at all) is forwarded to the base class so that eval() and\n",
    "        train(False) keep working on this module.\n",
    "\n",
    "        Args:\n",
    "            train_iter: Iterable of training batches accepted by\n",
    "                batch_to_input, or a bool mode flag.\n",
    "            val_iter: Optional iterable of validation batches scored after\n",
    "                every epoch.\n",
    "            num_epochs: Number of passes over train_iter (required to fit).\n",
    "            learning_rate: Step size handed to the Adam optimizer.\n",
    "\n",
    "        Returns:\n",
    "            The module itself when toggling the mode, otherwise None. The\n",
    "            per-epoch mean losses are stored in self.loss_vec.\n",
    "\n",
    "        Raises:\n",
    "            TypeError: If batches are given without num_epochs.\n",
    "        \"\"\"\n",
    "        if isinstance(train_iter, bool):\n",
    "            return super(LogisticRegression, self).train(train_iter)\n",
    "        if num_epochs is None:\n",
    "            raise TypeError('num_epochs is required when fitting the model')\n",
    "\n",
    "        # CrossEntropyLoss applies log-softmax itself, so it takes the raw\n",
    "        # linear scores. NLLLoss expects log-probabilities; feeding it sigmoid\n",
    "        # activations yields a negative, meaningless loss.\n",
    "        criterion = torch.nn.CrossEntropyLoss()\n",
    "        # Optimize this instance's parameters, not those of a notebook global.\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)\n",
    "        loss_vec = []\n",
    "\n",
    "        for epoch in range(1, num_epochs + 1):\n",
    "            epoch_loss, num_batches = 0.0, 0\n",
    "            for batch in train_iter:\n",
    "                x, y = self.batch_to_input(batch, train=True)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                # Targets are shifted to the 0-based range the loss expects.\n",
    "                loss = criterion(self.linear(x), y - 1)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                epoch_loss += self._loss_value(loss)\n",
    "                num_batches += 1\n",
    "\n",
    "            # max(..., 1) keeps an empty iterator from dividing by zero.\n",
    "            loss_vec.append(epoch_loss / max(num_batches, 1))\n",
    "            if val_iter is None:\n",
    "                print('Epoch {} loss: {}'.format(epoch, loss_vec[-1]))\n",
    "            else:\n",
    "                acc = self.validate(val_iter)\n",
    "                print('Epoch {} loss: {} | Valid acc: {}'.format(epoch, loss_vec[-1], acc))\n",
    "\n",
    "        plt.plot(range(len(loss_vec)), loss_vec)\n",
    "        plt.xlabel('Epoch')\n",
    "        plt.ylabel('Loss')\n",
    "        plt.show()\n",
    "        print('\\nModel trained.\\n')\n",
    "        self.loss_vec = loss_vec\n",
    "\n",
    "    def test(self, test_iter):\n",
    "        \"\"\"Score test_iter, print the accuracy and write predictions.txt.\n",
    "\n",
    "        One predicted label is written per line, in iteration order.\n",
    "        \"\"\"\n",
    "        upload, trues = self._collect_predictions(test_iter)\n",
    "\n",
    "        correct = sum(int(p == t) for p, t in zip(upload, trues))\n",
    "        accuracy = correct / len(trues) if trues else 0.0\n",
    "        print('Test Accuracy:', accuracy)\n",
    "\n",
    "        with open(\"predictions.txt\", \"w\") as f:\n",
    "            for u in upload:\n",
    "                f.write(str(u) + \"\\n\")\n",
    "\n",
    "    def validate(self, val_iter):\n",
    "        \"\"\"Return the fraction of val_iter examples predicted correctly.\n",
    "\n",
    "        Returns 0.0 when val_iter yields no examples.\n",
    "        \"\"\"\n",
    "        y_p, y_t = self._collect_predictions(val_iter)\n",
    "        correct = sum(int(p == t) for p, t in zip(y_p, y_t))\n",
    "        return correct / len(y_p) if y_p else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:37: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n",
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:84: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 loss: -0.8179293963950494 | Valid acc: 0.7729357798165137\n",
      "Epoch 2 loss: -0.9585080772982857 | Valid acc: 0.7694954128440367\n",
      "Epoch 3 loss: -0.9794496852538489 | Valid acc: 0.7706422018348624\n",
      "Epoch 4 loss: -0.9876801081303227 | Valid acc: 0.7752293577981652\n",
      "Epoch 5 loss: -0.9917574881473717 | Valid acc: 0.7821100917431193\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZoAAAEKCAYAAAArYJMgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xt0lfWd7/H3N/cEAuESEiBcE6qCImKqgheCAr3YilZ7ceoMdLRWcaYzy9VOndN1ZtaZmZ5le+ZMe5wRLWqVTqt1qu2gra0ggmLx0mC5yE3CHSQh3AkhIZfv+WM/wU3YIYFk72cn+bzWelaey2/v/d2blXz4Pc9v/x5zd0REROIlJewCRESkZ1PQiIhIXCloREQkrhQ0IiISVwoaERGJKwWNiIjElYJGRETiSkEjIiJxpaAREZG4Sgu7gGQwePBgHz16dNhliIh0K6tWrTrg7vnttVPQAKNHj6a8vDzsMkREuhUz29mRdjp1JiIicaWgERGRuFLQiIhIXCloREQkrhQ0IiISVwoaERGJKwWNiIjElYKmE7YfOMH/enk9DU3NYZciIpK0FDSdsP1ADU//YQeLVn8UdikiIklLQdMJ0y8awsWFucxfXkFTs4ddjohIUlLQdIKZMW96CduqT7B4fWXY5YiIJCUFTSfdfNlQRg/KYf7yrbirVyMi0pqCppNSU4z7phWzbu9RVmw5EHY5IiJJJ5SgMbOBZrbEzLYEPwe00e4HZrbezDaa2SMWkWNmvzWzTcGxh6PazzWzajNbHSz3JOL93DZ5OIX9snh0WUUiXk5EpFsJq0fzELDU3ccBS4PtM5jZVOBaYCJwKfBJYFpw+F/d/WLgCuBaM/tM1EOfd/dJwfJkPN9Ei8y0VL5+w1je3X6I8h2HEvGSIiLdRlhBMxtYGKwvBG6N0caBLCADyATSgSp3r3X3ZQDufgp4HyiKe8XtuPOqEQzISWf+8q1hlyIiklTCCpoCd98XrFcCBa0buPvbwDJgX7C86u4bo9uYWR7weSK9oha3m9laM3vBzEa0VYCZ3Wtm5WZWXl1d3cm3AzkZaXzt2jG8vmk/Gz461unnExHpKeIWNGb2mpl9EGOZHd3OI0O1zhquZWYlwCVEeivDgRvN7Pqo42nAc8Aj7r4t2P0yMNrdJwJL+LjXdBZ3X+Dupe5emp/f7p1IO2TOlNH0yUjlsTfUqxERaRG3oHH3Ge5+aYxlEVBlZkMBgp/7YzzFbcA77l7j7jXA74ApUccXAFvc/UdRr3nQ3euDzSeBK+Px3trSPyedu6aM4rdrP2L7gROJfGkRkaQV1qmzl4A5wfocYFGMNruAaWaWZmbpRAYCbAQws38B+gN/G/2AlvAK3NLSPpHuvm4Maakp/Fi9GhERILygeRiYaWZbgBnBNmZWamYtI8VeALYC64A1wBp3f9nMioDvAuOB91sNY/5mMOR5DfBNYG7C3lFgSG4WXy4dwYvv72Hf0ZOJfnkRkaRj+jY7lJaWenl5eZc93+5DtZT963LmTBnNP3x+fJc9r4hIMjGzVe5e2l47zQwQByMG5jB70jCee28XB2vq23+AiEgPpqCJk/unFXOyoYlnVu4IuxQRkVApaOJkXEEun5pQwMKVOzhe1xB2OSIioVHQxNG8shKO1TXy83d3hV2KiEhoFDRxdPmIPK4fN5gnV2ynrqEp7HJEREKhoImzeWUlHKip55flu8MuRUQkFAqaOLtm7EAmj8zj8Te20dDUHHY5IiIJp6CJMzPjgekl7D1ykpdWfxR2OSIiCaegSYAbLx7CxYW5PPbGVpqb9QVZEeldFDQJYGbcX1ZMxf4aFm+oCrscEZGEUtAkyM2XDWXUoBzmL69A0/6ISG+ioEmQtNQU7ptWzNo9R3mr4kDY5YiIJIyCJoG+MHk4Bf0yeXRZRdiliIgkjIImgTLTUvn69WN5Z9shVu08FHY5IiIJoaBJ
sDuvGsmAnHTmL9ON0USkd1DQJFifzDTmTh3D0k372bjvWNjliIjEnYImBHOmjqJPRiqPLVevRkR6PgVNCPJyMrjrmlH8Zu1H7DhwIuxyRETiSkETkruvG0Naago/flO9GhHp2UILGjMbaGZLzGxL8HNAG+1+YGbrzWyjmT1iZhbsX25mm81sdbAMCfZnmtnzZlZhZu+a2ejEvauOG9Iviy+VFvHCqj1UHq0LuxwRkbgJs0fzELDU3ccBS4PtM5jZVOBaYCJwKfBJYFpUk6+6+6Rg2R/suxs47O4lwA+B78fxPXTKN24optnhiRXbwi5FRCRuwgya2cDCYH0hcGuMNg5kARlAJpAOtDdZWPTzvgDc1NILSjYjBuZwy+XDePbdXRw6cSrsckRE4iLMoClw933BeiVQ0LqBu78NLAP2Bcur7r4xqsnTwWmz/xkVJsOB3cHjG4GjwKA4vYdOu7+smJMNTTyzckfYpYiIxEVcg8bMXjOzD2Iss6PbeWSWybNmmjSzEuASoIhIgNxoZtcHh7/q7pcB1wfLn59nbfeaWbmZlVdXV1/Au+sanyjIZdb4Ap75w3Zq6htDq0NEJF7iGjTuPsPdL42xLAKqzGwoQPBzf4ynuA14x91r3L0G+B0wJXjuvcHP48CzwFXBY/YCI4LnTQP6Awdj1LbA3UvdvTQ/P78r3/Z5mze9hGN1jfz8nZ2h1iEiEg9hnjp7CZgTrM8BFsVoswuYZmZpZpZOZCDAxmB7MECw/3PABzGe9w7gdU/yefknjcjjupLBPLFiO3UNTWGXIyLSpcIMmoeBmWa2BZgRbGNmpWb2ZNDmBWArsA5YA6xx95eJDAx41czWAquJ9GKeCB7zFDDIzCqAB4kxmi0ZzZtezIGaen65ak/YpYiIdClL8v/sJ0RpaamXl5eHWoO784XHVlJ9vJ5l3yojPVXfpRWR5GZmq9y9tL12+muWJMyMeWUl7Dl8kpfXfBR2OSIiXUZBk0RuungIFxXk8tjyrTQ3q6cpIj2DgiaJpKQY86YXs2V/DUs2tve9VBGR7kFBk2RuvmwoIwfmMH9ZBbp+JiI9gYImyaSlpnDftGLW7DnKHyrO+vqPiEi3o6BJQrdfOZwhuZk8uqwi7FJERDpNQZOEMtNSufeGsby97SDv7zocdjkiIp2ioElSd141krycdOYv043RRKR7U9AkqT6ZacydOprXNlaxqfJY2OWIiFwwBU0Smzt1NDkZqTy2XL0aEem+FDRJLC8ng7uuGcXLaz5i58ETYZcjInJBFDRJ7p7rxpCWksLjb+h2zyLSPSloktyQfll8sbSIF1ftofJoXdjliIicNwVNN/CNG4ppcufJFerViEj3o6DpBkYOyuHzE4fy7Hu7OHziVNjliIicFwVNN3F/WQm1p5p4ZuWOsEsRETkvCppu4qLCXGaOL+CZlTuoqW8MuxwRkQ5T0HQj88qKOXqygWff3Rl2KSIiHaag6UauGDmAa0sG8cSK7dQ1NIVdjohIh4QSNGY20MyWmNmW4OeANtr9wMzWm9lGM3vEInLNbHXUcsDMfhS0n2tm1VHH7knsO4u/B8pKqD5ezwur9oRdiohIh4TVo3kIWOru44ClwfYZzGwqcC0wEbgU+CQwzd2Pu/uklgXYCfwq6qHPRx1/Mu7vJMGmFA9i0og8fvzmVhqbmsMuR0SkXWEFzWxgYbC+ELg1RhsHsoAMIBNIB864v7GZfQIYAqyIW6VJxsyYV1bM7kMn+c3afWGXIyLSrrCCpsDdW/5KVgIFrRu4+9vAMmBfsLzq7htbNfsKkR5M9D2PbzeztWb2gpmNiEPtoZtxSQGfKOjL/OUVNDfrds8iktziFjRm9pqZfRBjmR3dLgiJs/5amlkJcAlQBAwHbjSz61s1+wrwXNT2y8Bod58ILOHjXlOs+u41s3IzK6+urr6g9xiWlBRjXlkJH1bV8NrGqvYfICISorgFjbvPcPdLYyyLgCozGwoQ/Nwf4yluA95x
9xp3rwF+B0xpOWhmlwNp7r4q6jUPunt9sPkkcOU56lvg7qXuXpqfn9/p95ton5s4lBEDs3l0+VbO7NCJiCSXsE6dvQTMCdbnAItitNkFTDOzNDNLB6YB0afO7uTM3kxLaLW4pVX7HiUtNYX7phWzZvcRVm49GHY5IiJtCitoHgZmmtkWYEawjZmVmlnLSLEXgK3AOmANsMbdX456ji/RKmiAbwbDodcA3wTmxu8thO/2yUUMyc3k0WUVYZciItIm02kXKC0t9fLy8rDLuCAL3tzK/35lE7+eN5UrRsb8OpKISFyY2Sp3L22vnWYG6Ob+7OpR9M9OZ75u9ywiSUpB0831zUxj7tTRLNlQxebK42GXIyJyFgVNDzB36mhyMlJ5bLmu1YhI8lHQ9AAD+mTw1atH8tKaj9h1sDbsckREzqCg6SHuuX4saSkpPP6mrtWISHJR0PQQBf2yuKO0iBfK91B1rC7sckRETlPQ9CDfuGEsjc3NPPXW9rBLERE5TUHTg4wa1IfPXz6Mn72zkyO1p8IuR0QEUND0OPeXFVN7qolnVu4IuxQREUBB0+NcXNiPGZcU8PQfdlBT3xh2OSIiCpqeaN70Yo6ebOC5d3eFXYqIiIKmJ5o8cgBTiwfxxIpt1DU0hV2OiPRyCpoe6oHpJew/Xs+L7+8JuxQR6eUUND3U1OJBXF7Unx+/sY3GpuawyxGRXkxB00OZGfOml7DrUC2/Xbcv7HJEpBdT0PRgMy8pYNyQvsxftpXmZt13SETCoaDpwVJSjHnTi9lcdZylm/aHXY6I9FIKmh7u8xOHUTQgm/9YVoHupioiYVDQ9HBpqSncN62YNbuP8PbWg2GXIyK9kIKmF7jjyiLyczN1u2cRCUWHgsbMis0sM1gvM7NvmlleZ17YzAaa2RIz2xL8HNBGu++b2QfB8uWo/WPM7F0zqzCz580sI9ifGWxXBMdHd6bOniArPZV7rhvDWxUHWL37SNjliEgv09EezYtAk5mVAAuAEcCznXzth4Cl7j4OWBpsn8HMbgYmA5OAq4FvmVm/4PD3gR+6ewlwGLg72H83cDjY/8OgXa/31WtG0S8rjfnLdLtnEUmsjgZNs7s3ArcB/+7u3waGdvK1ZwMLg/WFwK0x2owH3nT3Rnc/AawFPm1mBtwIvBDj8dHP+wJwU9C+V+ubmcbca8eweEMVH1YdD7scEelFOho0DWZ2JzAH+E2wL72Tr13g7i3fJKwECmK0WUMkWHLMbDAwnUhvahBwJAg/gD3A8GB9OLAbIDh+NGjf631t6mhyMlJ5TNdqRCSBOho0XwOmAN9z9+1mNgb4z/YeZGavRV1fiV5mR7fzyLjbs8beuvti4BVgJfAc8DbQJbNEmtm9ZlZuZuXV1dVd8ZRJb0CfDP7sqpG8tOYjdh2sDbscEeklOhQ07r7B3b/p7s8FF+1z3b3dax/uPsPdL42xLAKqzGwoQPAz5jcK3f177j7J3WcCBnwIHATyzCwtaFYE7A3W9xLp9RAc7x+0b/28C9y91N1L8/PzO/Ix9Aj3XD+WVDN+/KZ6NSKSGB0ddbbczPqZ2UDgfeAJM/u3Tr72S0ROxRH8XBTjdVPNbFCwPhGYCCwOekDLgDtiPD76ee8AXnd9U/G0wv5Z3H7lcH65ag/7j9WFXY6I9AIdPXXW392PAV8AfuruVwMzOvnaDwMzzWxL8FwPA5hZqZk9GbRJB1aY2QYio93uirou8x3gQTOrIHIN5qlg/1PAoGD/g8QYzdbbfeOGYhqbmnnqre1hlyIivUBa+00i7YLTW18CvtsVL+zuB4GbYuwvB+4J1uuIjDyL9fhtwFUx9tcBX+yKGnuq0YP78LmJw/jZOzu5v6yYvJyMsEsSkR6soz2afwJeBba6+x/NbCywJX5lSbzdX1bMiVNNLFy5M+xSRKSH6+hggF+6+0R3vz/Y3ubut8e3NImnS4b2Y8YlQ3h65XZO1De2/wARkQvU0cEARWb2azPbHywvmllR
vIuT+Jo3vYQjtQ08996usEsRkR6so6fOniYymmtYsLwc7JNubPLIAUwZO4gnVmyjvrFLvp4kInKWjgZNvrs/HUwF0+juzwC958snPdi86cVUHavnV+/vbb+xiMgF6GjQHDSzu4LvtaSa2V3E+BKkdD/XlQxmYlF/Hn9jK41NzWGXIyI9UEeD5i+JDG2uBPYR+SLk3DjVJAlkZswrK2HnwVp+u25f+w8QETlPHR11ttPdb3H3fHcf4u63Ahp11kPMGl9AyZC+zF+2leZmTaIgIl2rM3fYfLDLqpBQpaQY88qK2Vx1nNc3xZxyTkTkgnUmaHr9PV56ks9fPoyiAdk8urwCTQ0nIl2pM0Gjv0Y9SHpqCt+YVsyfdh3hnW2Hwi5HRHqQcwaNmR03s2MxluNEvk8jPcgXryxicN9M5i/X7Z5FpOucM2jcPdfd+8VYct29oxNySjeRlZ7KPdePYcWWA6zZfSTsckSkh+jMqTPpgb569Uj6ZaWpVyMiXUZBI2fIzUpn7tTRvLq+ii1Vx8MuR0R6AAWNnGXutWPITk/lseW63bOIdJ6CRs4ysE8Gf3b1SBat+Yjdh2rDLkdEujkFjcT09evHkmKw4M1tYZciIt2cgkZiKuyfxe2Ti3i+fDf7j9eFXY6IdGOhBI2ZDTSzJWa2Jfg5oI123zezD4Lly1H7f25mm4P9PzGz9GB/mZkdNbPVwfIPiXpPPdE3phXT2NTMU29tD7sUEenGwurRPAQsdfdxwNJg+wxmdjMwGZgEXA18y8z6BYd/DlwMXAZkA/dEPXSFu08Kln+K43vo8cYM7sPNE4fxs7d3crS2IexyRKSbCitoZgMLg/WFwK0x2owH3gxutHYCWAt8GsDdX/EA8B6g20rHybyyYk6camLh2zvCLkVEuqmwgqbA3VtuflIJFMRoswb4tJnlmNlgYDowIrpBcMrsz4HfR+2eYmZrzOx3ZjYhDrX3KpcM7cdNFw/hJ3/Yzon6xrDLEZFuKG5BY2avRV1fiV5mR7cLeiVnTdDp7ouBV4CVwHPA20DrG9vPJ9LrWRFsvw+McvfLgX8H/vsc9d1rZuVmVl5dXX2hb7NXmDe9hCO1DTz33q6wSxGRbihuQePuM9z90hjLIqDKzIYCBD9j3gTF3b8XXGuZSeS2BB+2HDOzfwTyibovjrsfc/eaYP0VID3oDcV67gXuXurupfn5+V30rnumK0cN4OoxA3lyxXbqG1tnvYjIuYV16uwlYE6wPgdY1LqBmaWa2aBgfSIwEVgcbN8DfAq4092box5TaGYWrF9F5P0djOP76DUemF5C5bE6fv3+3rBLEZFuJqygeRiYaWZbgBnBNmZWamZPBm3SgRVmtgFYANzl7i0XCR4ncl3n7VbDmO8APjCzNcAjwFdcd/HqEtePG8xlw/vz2BtbaWxqbv8BIiIB099hKC0t9fLy8rDLSHq//2Af9/3sfR658wpuuVy3IxLp7cxslbuXttdOMwNIh80aX0hxfh/mL9PtnkWk4xQ00mEpKca8shI2VR7n9U0xx2+IiJxFQSPn5ZZJwxiel82j6tWISAcpaOS8pKem8I1pY3l/1xHe3X4o7HJEpBtQ0Mh5+1LpCAb3zeDRZbrds4i0T0Ej5y0rPZW7rxvLii0HWLvnSNjliEiSU9DIBbnrmpHkZqUxf5lu9ywi56agkQuSm5XO3Kmj+f36SrZUHQ+7HBFJYgoauWBfu3YM2empPPaGejUi0jYFjVywgX0yuPOqkSxa/RG7D9WGXY6IJCkFjXTK128YQ4rBEyu2hV2KiCQpBY10ytD+2XzhiiJ+8cfd7D9eF3Y5IpKEFDTSafeVFdPY1MxP3toRdikikoQUNNJpYwb34bOXDeVn7+zkaG1D2OWISJJR0EiXmFdWQk19Iz99e0fYpYhIklHQSJcYP6wfN148hJ/8YTu1pxrbf4CI9BoKGukyD0wv5nBtA794b3fYpYhIElHQSJe5ctRArhozkAVvbuNUo273
LCIRChrpUg9ML6HyWB2//tOesEsRkSShoJEudcO4wVw6vB+PLd9KU7NujCYiIQWNmQ00syVmtiX4OaCNdt83sw+C5ctR+58xs+1mtjpYJgX7zcweMbMKM1trZpMT9Z4kwsx4oKyEHQdreWXdvrDLEZEkEFaP5iFgqbuPA5YG22cws5uBycAk4GrgW2bWL6rJt919UrCsDvZ9BhgXLPcCj8XxPUgbPjWhkOL8Prrds4gA4QXNbGBhsL4QuDVGm/HAm+7e6O4ngLXApzvwvD/1iHeAPDMb2lVFS8ekpBj3l5WwqfI4yzbvD7scEQlZWEFT4O4t51UqgYIYbdYAnzazHDMbDEwHRkQd/15weuyHZpYZ7BsORI+t3RPsO4uZ3Wtm5WZWXl1d3ak3I2ebPWkYw/OyeXTZVvVqRHq5uAWNmb0WdX0lepkd3c4jf4XO+kvk7ouBV4CVwHPA20BTcPjvgYuBTwIDge+cb33uvsDdS929ND8//3wfLu1IT03h3hvGsmrnYd7bfijsckQkRHELGnef4e6XxlgWAVUtp7SCnzHPr7j794JrMDMBAz4M9u8LTo/VA08DVwUP2cuZvZ6iYJ+E4MufHMHgvhk8ulw3RhPpzcI6dfYSMCdYnwMsat3AzFLNbFCwPhGYCCwOtltCyohc3/kg6nn/Ihh9dg1wNOoUnSRYVnoqf3ndGN78sJp1e46GXY6IhCSsoHkYmGlmW4AZwTZmVmpmTwZt0oEVZrYBWADc5e4tk2j93MzWAeuAwcC/BPtfAbYBFcATwLxEvBlp213XjCI3K435yyvCLkVEQpIWxou6+0Hgphj7y4F7gvU6IiPPYj3+xjb2O/BA11UqndUvK505U0bz6PIKKvbXUDKkb9gliUiCaWYAibuvXTuazLQUHn9D12pEeiMFjcTdoL6ZfOWTI/nvP+1lz+HasMsRkQRT0EhC3HvDWACeeHNbyJWISKIpaCQhhuVl84XJw/nFH3fzy/LdHD5xKuySRCRBQhkMIL3TX984jne2HeLbL6wlNcX45OgBzBpfyMzxBYwYmBN2eSISJ6bpQaC0tNTLy8vDLqNXcHc+2HuMxRsqWby+is1VxwGYMKwfs8YXMmtCARcX5hL5ipSIJDMzW+Xupe22U9AoaMK0/cAJlgShs2rXYdxhxMDsSOiML6B09EBSUxQ6IslIQXMeFDTJofp4PUs3VrF4QxVvbTnAqaZmBvbJ4KaLhzBrQiHXjxtMVnpq2GWKSEBBcx4UNMmnpr6RNz+sZvH6SpZu2s/xukay01OZ9ol8Zk0o4MaLh5CXkxF2mSK9WkeDRoMBJCn1zUzjs5cN5bOXDeVUYzPvbj/I4vVVLN5Qye/XV5KaYlw9ZiCzxhcwc0Ihw/Oywy5ZRNqgHg3q0XQnzc3Our1HTw8m2LK/BoBLh/fjU+MLmTWhkE8U9NVgApEE0Kmz86Cg6b62VdeweEMVi9dX8qfdR3CHUYNymDW+gFkTCpk8coAGE4jEiYLmPChoeob9x+p4beN+Fm+oZGXFQU41NTOoTwYzLilg1oQCri3RYAKRrqSgOQ8Kmp7neF0Db3xYzavrq1i2aT819Y3kZKRSdlE+s8YXMv2iIfTPSQ+7TJFuTYMBpFfLzUrncxOH8bmJw6hvbOKdbYdYvL6SJRuqeGVdJWkpxjVjBzFrQgEzxxcwtL8GE4jEi3o0qEfTmzQ3O2v2HGHxhipeXV/JtuoTAEws6s+s8QV8akIhJUM0mECkI3Tq7DwoaHqviv01p0ewrd59BIAxg/sEgwkKuGLEAFI0mEAkJgXNeVDQCEDVsTqWbIjMTPD21gM0NDmD+2Yyc/wQZo0vZErxIA0mEImioDkPChpp7VhdA8s3V/Pq+kqWb9rPiVNN9MlIpeyiIcyaUMD0i4fQL0uDCaR3S+qgMbOBwPPAaGAH8CV3Pxyj3feBm4PNf3b354P9K4DcYP8Q4D13v9XMyoBFwPbg2K/c/Z/aq0dBI+dS39jEyq2R
mQmWbKjiQE096aktgwkKmXlJAYX9s8IuUyThkj1ofgAccveHzewhYIC7f6dVm5uBvwU+A2QCy4Gb3P1Yq3YvAovc/adB0HzL3T93PvUoaKSjmpudP+0+cvq6zvYDkcEEl4/ICwYTFFAyJLedZxHpGZI9aDYDZe6+z8yGAsvd/aJWbb4NZLn7PwfbTwGvuvt/RbXpB+wERrn7MQWNJJK7B4MJIjMTrNlzFICx+X1O31tnUlGeBhNIj5Xs36MpcPd9wXolUBCjzRrgH83s/wI5wHRgQ6s2twJLW/VyppjZGuAjIqGzvmtLF4kwM8YV5DKuIJcHppew7+hJXgsGEzy5YhuPv7GV/NxMZo4vYNb4AqYWDyYjTXdPl94nbj0aM3sNKIxx6LvAQnfPi2p72N0HxHiO7wJfBKqB/cAf3f1HUcd/Bzzp7i8G2/2AZnevMbPPAv/P3ce1Ud+9wL0AI0eOvHLnzp0X+E5Fzna0toFlmyPT4SzfXE3tqSZyM9Mou3gIs8YXUHZRPrkaTCDdXLc/dRbjMc8CP3P3V4LtwcBmYLi717XxmB1AqbsfONdz69SZxFNdQxMrtx44PZjg4IlTpKcaU4sHR2YmuKSAIf00mEC6n2QPmv8DHIwaDDDQ3f+uVZtUIM/dD5rZROBZYJK7NwbH7wOmuPucqMcUAlXu7mZ2FfACkes353yTChpJlKZm50+7Dp+emWDnwVoArhiZx6zxhXxqQgFj8/uGXKVIxyR70AwC/gsYSeRi/pfc/ZCZlQL3ufs9ZpYFvB885Fiwf3XUcywHHnb330ft+yvgfqAROAk86O4r26tHQSNhcHc+rKph8fpKFm+oYt3eyGCCkiF9T9/mYOLw/hpMIEkrqYMm2ShoJBnsPdIymKCSd7YdoqnZGZKbyUWFuQzPy6ZoQDbDB2RTNCCH4XnZFPTL0r12JFQKmvOgoJFkc6T2FMs272f55mp2HKxl7+GTHKipP6NNWooxNC+LorycIICyg0DKoWhANoX9s0hP1Sg3iZ9kH94sIueQl5PBbVcUcdsVRaf31TU0sffISfYcPsnewyfZc7j29PZbWw5QdbyO6P83phgU9ss6oxcU3SsalpdFZprmbpP4U9CIdBNZ6akU5/eluI3BAvWGVDXGAAAKX0lEQVSNTVQerTsjiPYEQfTe9kPsO3qS5lYnMIbkZsYMohEDshmel0N2hoJIOk9BI9JDZKalMmpQH0YN6hPzeGNTM5XHooPoJHuP1LLn8EnW7jnC7z/YR0PTmUk0qE/GWaflhudlUzQwsq3vAklHKGhEeom01JTg+k1OzONNzU718fozTsntCXpGmyqPs3Tjfuobm894TP/s9JgDFYqCcOqfna6byImCRkQiUlOMwv5ZFPbPItbVXXfnQM2pM4Ko5RTdjoMneKviALWnms54TN/MtDOC6HSvKAiiQX0yFES9gIJGRDrEzMjPzSQ/N5MrRp41YxTuzpHahjNOye05fYruJO/tOMTxusYzHpOVnsLwvGyGByPlontDRQNyyO+bqe8R9QAKGhHpEmbGgD4ZDOiTwWVF/WO2OXqygb1B8Ow5XBt1regk6/Yc4XBtwxntM1JTGJYXjJxrPYx7YA4FuZmkaQh30lPQiEjC9M9Op392OuOH9Yt5/ER9I3uPnD1qbu/hkyzdtP+s7xKlphhD+2edcUpuYE46/bLTyc1Kp19WWuRndhr9stPpm5GmHlIIFDQikjT6ZKbxiYJcPlEQ++ZxLd8litUrWrn1AJXHzvwuUWtmketG/bLSyc2KhE+/rNbbH6/nBsei13Wrh/OnoBGRbqO97xI1NDVzvK6RYycbOFbXcHr9eF0jx+oagv2NZxzbe6SOTXXHI+3qG88ZVJEaUlr1llr1nIJj0eEUfSwnI7XXDYBQ0IhIj5GemsLAPhkM7JNxQY9vbnZOnGrkWF0jx+saOHYy+HnG+sfHjtU1cPRkA3sO1UYC7GQDp5qaz/kaqSlGblbax72ldntT
Z7bLzUrrdtelFDQiIoGUFCM36IFA9gU9R11D0+keVHu9q5b1XYdqT/e2auob232NPhmpp3tJue30oKJDrKVdZlpKQntVChoRkS6UlZ5KVnoq+bmZF/T4pmanpiWMztGTOt3jqm/gQM0pth04cTrMGlvPNdRKeqqdvvb01atHcs/1Yy+o1o5S0IiIJJHUFKN/Tjr9cy5seh93p66huc1rUq17VxcaiOdDQSMi0oOYGdkZqWRnpFKQJLcI715XlEREpNtR0IiISFwpaEREJK4UNCIiElehBY2ZfdHM1ptZs5m1ec9pM/u0mW02swozeyhq/xgzezfY/7yZZQT7M4PtiuD46Pi/GxERaUuYPZoPgC8Ab7bVwMxSgUeBzwDjgTvNbHxw+PvAD929BDgM3B3svxs4HOz/YdBORERCElrQuPtGd9/cTrOrgAp33+bup4BfALMt8pXWG4EXgnYLgVuD9dnBNsHxm6y3TSwkIpJEkv0azXBgd9T2nmDfIOCIuze22n/GY4LjR4P2ZzCze82s3MzKq6ur41S+iIjE9QubZvYaUBjj0HfdfVE8X7s97r4AWABgZtVmtvMCn2owcKDLCus6yVoXJG9tquv8qK7z0xPrGtWRRnENGnef0cmn2AuMiNouCvYdBPLMLC3otbTsj37MHjNLA/oH7c9VZ/6FFmhm5e7e5mCGsCRrXZC8tamu86O6zk9vrivZT539ERgXjDDLAL4CvOTuDiwD7gjazQFaekgvBdsEx18P2ouISAjCHN58m5ntAaYAvzWzV4P9w8zsFTh9jeWvgFeBjcB/ufv64Cm+AzxoZhVErsE8Fex/ChgU7H8QOD0kWkREEi+0STXd/dfAr2Ps/wj4bNT2K8ArMdptIzIqrfX+OuCLXVrsuS1I4Gudj2StC5K3NtV1flTX+em1dZnOKomISDwl+zUaERHp5hQ0HdTWVDhRx0OZ+qYDdc0Nhm+vDpZ7ElTXT8xsv5l90MZxM7NHgrrXmtnkJKmrzMyORn1e/5CAmkaY2TIz2xBMy/Q3Mdok/PPqYF0J/7yC180ys/fMbE1Q2/+K0Sbhv5MdrCus38lUM/uTmf0mxrH4flburqWdBUgFtgJjgQxgDTC+VZt5wOPB+leA55OkrrnAf4Twmd0ATAY+aOP4Z4HfAQZcA7ybJHWVAb9J8Gc1FJgcrOcCH8b4d0z459XBuhL+eQWva0DfYD0deBe4plWbMH4nO1JXWL+TDwLPxvr3ivdnpR5Nx8ScCqdVmzCmvulIXaFw9zeBQ+doMhv4qUe8Q+R7UUOToK6Ec/d97v5+sH6cyAjL4a2aJfzz6mBdoQg+h5pgMz1YWl9wTvjvZAfrSjgzKwJuBp5so0lcPysFTce0NRVOzDZ+jqlvQqgL4PbgdMsLZjYixvEwdLT2MEwJTn38zswmJPKFg1MWVxD5n3C0UD+vc9QFIX1ewamg1cB+YIm7t/mZJfB3siN1QeJ/J38E/B3Q3MbxuH5WCpqe72VgtLtPBJbw8f9aJLb3gVHufjnw78B/J+qFzawv8CLwt+5+LFGv25526grt83L3JnefRGRmkKvM7NJEvfa5dKCuhP5OmtnngP3uviqer3MuCpqOaWsqnJhtrINT3ySiLnc/6O71weaTwJVxrqmjOvKZJpy7H2s59eGR73Clm9ngeL+umaUT+WP+c3f/VYwmoXxe7dUV1ufVqoYjRGYK+XSrQ2H8TrZbVwi/k9cCt5jZDiKn1280s5+1ahPXz0pB0zExp8Jp1SaMqW/aravVefxbiJxnTwYvAX8RjKa6Bjjq7vvCLsrMClvOTZvZVUR+R+L6xyl4vaeAje7+b200S/jn1ZG6wvi8gtfKN7O8YD0bmAlsatUs4b+THakr0b+T7v737l7k7qOJ/I143d3vatUsrp9VaDMDdCfu3mhmLVPhpAI/cff1ZvZPQLm7v0TkF/I/LTL1zSEi/6DJUNc3zewWoDGoa2686wIws+eIjEgabJGphv6R
yIVR3P1xIrM9fBaoAGqBryVJXXcA95tZI3AS+EoC/sNwLfDnwLrg3D7A/wBGRtUVxufVkbrC+LwgMiJuoUVujphCZHqq34T9O9nBukL5nWwtkZ+VZgYQEZG40qkzERGJKwWNiIjElYJGRETiSkEjIiJxpaAREZG4UtCIJICZNUXN1rvaYsy03YnnHm1tzEYtkgz0PRqRxDgZTEsi0uuoRyMSIjPbYWY/MLN1wX1MSoL9o83s9WDixaVmNjLYX2Bmvw4msVxjZlODp0o1sycscg+UxcG30kWSgoJGJDGyW506+3LUsaPufhnwH0Rm2YXIBJULg4kXfw48Eux/BHgjmMRyMrA+2D8OeNTdJwBHgNvj/H5EOkwzA4gkgJnVuHvfGPt3ADe6+7ZgAstKdx9kZgeAoe7eEOzf5+6DzawaKIqalLFlCv8l7j4u2P4OkO7u/xL/dybSPvVoRMLnbayfj/qo9SZ0/VWSiIJGJHxfjvr5drC+ko8nNvwqsCJYXwrcD6dvsNU/UUWKXCj9r0ckMbKjZkAG+L27twxxHmBma4n0Su4M9v018LSZfRuo5uPZmv8GWGBmdxPpudwPhH57BZFz0TUakRAF12hK3f1A2LWIxItOnYmISFypRyMiInGlHo2IiMSVgkZEROJKQSMiInGloBERkbhS0IiISFwpaEREJK7+P/p97lmV6RRsAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1233e5e10>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Model trained.\n",
      "\n",
      "Test Accuracy: 0.7781438769906645\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:67: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n"
     ]
    }
   ],
   "source": [
    "# Two-class model over the whole vocabulary; 10 is the iterators' batch size.\n",
    "model = LogisticRegression(input_size=len(text.vocab), num_classes=2, batch_size=10)\n",
    "model.train(\n",
    "    train_iter=train_iter,\n",
    "    val_iter=val_iter,\n",
    "    num_epochs=5,\n",
    "    learning_rate=1e-3,\n",
    ")\n",
    "model.test(test_iter=test_iter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
